{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Omni-Math Preprocessing"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Filter for non-proof problems."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# Read the raw Omni-Math dataset. The encoding is pinned to UTF-8 so the\n",
    "# load does not depend on the platform's default locale encoding.\n",
    "with open('../train/omni_math.json', 'r', encoding='utf-8') as f:\n",
    "    omni = json.load(f)\n",
    "# Number of entries loaded (shown as the cell output).\n",
    "len(omni)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Filter for non-proof problems.\n",
    "from copy import deepcopy\n",
    "from tqdm import tqdm\n",
    "from concurrent.futures import ProcessPoolExecutor, as_completed\n",
    "from pprint import pprint\n",
    "\n",
    "from deepscaler.utils import call_gemini_llm\n",
    "from deepscaler.system_prompts import FILTER_PROOF_PROMPT\n",
    "\n",
    "def filter_proofs(idx, entry):\n",
    "    \"\"\"Decide whether one dataset entry is kept, based on the Gemini responses.\n",
    "\n",
    "    Args:\n",
    "        idx: Position of `entry` in the dataset; returned unchanged so the\n",
    "            caller can match the result to its slot.\n",
    "        entry: Mapping that provides at least the keys 'problem' and 'answer'.\n",
    "\n",
    "    Returns:\n",
    "        `(idx, entry)` when the entry is kept, `(idx, {})` when it is dropped.\n",
    "    \"\"\"\n",
    "    problem_text = entry['problem']\n",
    "    solution_text = entry['answer']\n",
    "    # Ask for 4 samples at temperature 0.8 for this problem/answer pair.\n",
    "    # NOTE(review): assumes the helper returns a list of strings - confirm.\n",
    "    outputs = call_gemini_llm(f'Problem: {problem_text} \\n\\n Answer: {solution_text}',\n",
    "                              system_prompt=FILTER_PROOF_PROMPT, temperature=0.8, n=4)\n",
    "    # Fail open: an empty or missing response keeps the entry.\n",
    "    if not outputs:\n",
    "        return idx, entry\n",
    "    # One sample containing the '[[1]]' marker is enough to keep the entry.\n",
    "    # NOTE(review): presumably '[[1]]' is the non-proof verdict requested by\n",
    "    # FILTER_PROOF_PROMPT - confirm against the prompt text.\n",
    "    if any('[[1]]' in output for output in outputs):\n",
    "        return idx, entry\n",
    "    # Log the entry being dropped together with the first model response.\n",
    "    pprint(problem_text)\n",
    "    pprint(solution_text)\n",
    "    pprint(outputs[0])\n",
    "    return idx, {}\n",
    "\n",
    "# Work on a copy so `omni` keeps the unfiltered entries.\n",
    "data = deepcopy(omni)\n",
    "\n",
    "with ProcessPoolExecutor(max_workers=32) as executor:\n",
    "    # 1) Submit all jobs to the executor\n",
    "    futures = [executor.submit(filter_proofs, f_idx, entry) for f_idx, entry in enumerate(data)]\n",
    "\n",
    "    # 2) Process them as they complete, using tqdm for a progress bar.\n",
    "    # This loop must stay inside the `with` block: leaving the block waits for\n",
    "    # every pending future, so a loop placed after it would only start once\n",
    "    # all work had finished and the bar would show no real progress.\n",
    "    for future in tqdm(as_completed(futures), total=len(futures), desc=\"Processing entries\"):\n",
    "        # Get the result for each completed future\n",
    "        idx, result = future.result()\n",
    "        data[idx] = result\n",
    "\n",
    "# Dropped entries were replaced by {} above; remove them.\n",
    "data = [d for d in data if d]\n",
    "# Save final list as json\n",
    "with open(\"omni_math.json\", \"w\") as f:\n",
    "    json.dump(data, f, indent=2)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "deepscaler",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
